{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "(GS_ETE)=\n",
    "\n",
    "# **Getting Started: End to End**\n",
    "\n",
    "NVIDIA(R) TensorFlow 2.x Quantization toolkit provides a simple API to quantize a given Keras model. At a higher level, Quantization Aware Training (QAT) is a three-step workflow as shown below:\n",
    "\n",
    "```{eval-rst}\n",
    ".. mermaid::\n",
    "\n",
    "    flowchart LR\n",
    "        id1(Pre-trained model) --> id2(Quantize) --> id3(Fine-tune)\n",
    "\n",
    "```\n",
    "Initially, the network is trained on the target dataset until fully converged. The Quantization step consists of inserting Q/DQ nodes in the pre-trained network to simulate quantization during training. Note that simply adding Q/DQ nodes will result in reduced accuracy since the quantization parameters are not yet updated for the given model. The network is then re-trained for a few epochs to recover accuracy in a step called \"fine-tuning\".\n",
    "\n",
    "```{eval-rst}\n",
    "\n",
    ".. admonition:: Goal\n",
    "    :class: note\n",
    "\n",
    "    #. Train a simple network on the Fashion MNIST dataset and save it as the baseline model.\n",
    "    #. Quantize the pre-trained baseline network.\n",
    "    #. Fine-tune the quantized network to recover accuracy and save it as the QAT model.\n",
    "\n",
    "```\n",
    "---\n",
    "\n",
    "## 1. Train\n",
    "Import required libraries and create a simple network with convolution and dense layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"original\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " nn_input (InputLayer)       [(None, 28, 28)]          0         \n",
      "                                                                 \n",
      " reshape_0 (Reshape)         (None, 28, 28, 1)         0         \n",
      "                                                                 \n",
      " conv_0 (Conv2D)             (None, 26, 26, 126)       1260      \n",
      "                                                                 \n",
      " relu_0 (ReLU)               (None, 26, 26, 126)       0         \n",
      "                                                                 \n",
      " conv_1 (Conv2D)             (None, 24, 24, 64)        72640     \n",
      "                                                                 \n",
      " relu_1 (ReLU)               (None, 24, 24, 64)        0         \n",
      "                                                                 \n",
      " conv_2 (Conv2D)             (None, 22, 22, 32)        18464     \n",
      "                                                                 \n",
      " relu_2 (ReLU)               (None, 22, 22, 32)        0         \n",
      "                                                                 \n",
      " conv_3 (Conv2D)             (None, 20, 20, 16)        4624      \n",
      "                                                                 \n",
      " relu_3 (ReLU)               (None, 20, 20, 16)        0         \n",
      "                                                                 \n",
      " conv_4 (Conv2D)             (None, 18, 18, 8)         1160      \n",
      "                                                                 \n",
      " relu_4 (ReLU)               (None, 18, 18, 8)         0         \n",
      "                                                                 \n",
      " max_pool_0 (MaxPooling2D)   (None, 9, 9, 8)           0         \n",
      "                                                                 \n",
      " flatten_0 (Flatten)         (None, 648)               0         \n",
      "                                                                 \n",
      " dense_0 (Dense)             (None, 100)               64900     \n",
      "                                                                 \n",
      " relu_5 (ReLU)               (None, 100)               0         \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 10)                1010      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 164,058\n",
      "Trainable params: 164,058\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow_quantization import quantize_model\n",
    "from tensorflow_quantization import utils\n",
    "\n",
    "assets = utils.CreateAssetsFolders(\"GettingStarted\")\n",
    "assets.add_folder(\"example\")\n",
    "\n",
    "def simple_net():\n",
    "    \"\"\"\n",
    "    Return a simple neural network.\n",
    "    \"\"\"\n",
    "    input_img = tf.keras.layers.Input(shape=(28, 28), name=\"nn_input\")\n",
    "    x = tf.keras.layers.Reshape(target_shape=(28, 28, 1), name=\"reshape_0\")(input_img)\n",
    "    x = tf.keras.layers.Conv2D(filters=126, kernel_size=(3, 3), name=\"conv_0\")(x)\n",
    "    x = tf.keras.layers.ReLU(name=\"relu_0\")(x)\n",
    "    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), name=\"conv_1\")(x)\n",
    "    x = tf.keras.layers.ReLU(name=\"relu_1\")(x)\n",
    "    x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), name=\"conv_2\")(x)\n",
    "    x = tf.keras.layers.ReLU(name=\"relu_2\")(x)\n",
    "    x = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), name=\"conv_3\")(x)\n",
    "    x = tf.keras.layers.ReLU(name=\"relu_3\")(x)\n",
    "    x = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), name=\"conv_4\")(x)\n",
    "    x = tf.keras.layers.ReLU(name=\"relu_4\")(x)\n",
    "    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name=\"max_pool_0\")(x)\n",
    "    x = tf.keras.layers.Flatten(name=\"flatten_0\")(x)\n",
    "    x = tf.keras.layers.Dense(100, name=\"dense_0\")(x)\n",
    "    x = tf.keras.layers.ReLU(name=\"relu_5\")(x)\n",
    "    x = tf.keras.layers.Dense(10, name=\"dense_1\")(x)\n",
    "    return tf.keras.Model(input_img, x, name=\"original\")\n",
    "\n",
    "# create model\n",
    "model = simple_net()\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Load Fashion MNIST data and split train and test sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load Fashion MNIST dataset\n",
    "mnist = tf.keras.datasets.fashion_mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "\n",
    "# Normalize the input image so that each pixel value is between 0 to 1.\n",
    "train_images = train_images / 255.0\n",
    "test_images = test_images / 255.0   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Compile the model and train for five epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "422/422 [==============================] - 4s 8ms/step - loss: 0.5639 - accuracy: 0.7920 - val_loss: 0.4174 - val_accuracy: 0.8437\n",
      "Epoch 2/5\n",
      "422/422 [==============================] - 3s 8ms/step - loss: 0.3619 - accuracy: 0.8696 - val_loss: 0.4134 - val_accuracy: 0.8433\n",
      "Epoch 3/5\n",
      "422/422 [==============================] - 3s 8ms/step - loss: 0.3165 - accuracy: 0.8855 - val_loss: 0.3137 - val_accuracy: 0.8812\n",
      "Epoch 4/5\n",
      "422/422 [==============================] - 3s 8ms/step - loss: 0.2787 - accuracy: 0.8964 - val_loss: 0.2943 - val_accuracy: 0.8890\n",
      "Epoch 5/5\n",
      "422/422 [==============================] - 3s 8ms/step - loss: 0.2552 - accuracy: 0.9067 - val_loss: 0.2857 - val_accuracy: 0.8952\n",
      "Baseline test accuracy: 0.888700008392334\n"
     ]
    }
   ],
   "source": [
    "# Train original classification model\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "\n",
    "model.fit(\n",
    "    train_images, train_labels, batch_size=128, epochs=5, validation_split=0.1\n",
    ")\n",
    "\n",
    "# get baseline model accuracy\n",
    "_, baseline_model_accuracy = model.evaluate(\n",
    "    test_images, test_labels, verbose=0\n",
    ")\n",
    "print(\"Baseline test accuracy:\", baseline_model_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: GettingStarted/example/fp32/saved_model/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX conversion Done!\n"
     ]
    }
   ],
   "source": [
    "# save TF FP32 original model\n",
    "tf.keras.models.save_model(model, assets.example.fp32_saved_model)\n",
    "\n",
    "# Convert FP32 model to ONNX\n",
    "utils.convert_saved_model_to_onnx(saved_model_dir = assets.example.fp32_saved_model, onnx_model_path = assets.example.fp32_onnx_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 2. Quantize\n",
    "\n",
    "Full model quantization is the most basic quantization mode someone can follow. In this mode, Q/DQ nodes are inserted in all supported keras layers, with a single function call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Quantize model\n",
    "quantized_model = quantize_model(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Keras model summary shows all supported layers wrapped into QDQ wrapper class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"original\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " nn_input (InputLayer)       [(None, 28, 28)]          0         \n",
      "                                                                 \n",
      " reshape_0 (Reshape)         (None, 28, 28, 1)         0         \n",
      "                                                                 \n",
      " quant_conv_0 (Conv2DQuantiz  (None, 26, 26, 126)      1515      \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      " relu_0 (ReLU)               (None, 26, 26, 126)       0         \n",
      "                                                                 \n",
      " quant_conv_1 (Conv2DQuantiz  (None, 24, 24, 64)       72771     \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      " relu_1 (ReLU)               (None, 24, 24, 64)        0         \n",
      "                                                                 \n",
      " quant_conv_2 (Conv2DQuantiz  (None, 22, 22, 32)       18531     \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      " relu_2 (ReLU)               (None, 22, 22, 32)        0         \n",
      "                                                                 \n",
      " quant_conv_3 (Conv2DQuantiz  (None, 20, 20, 16)       4659      \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      " relu_3 (ReLU)               (None, 20, 20, 16)        0         \n",
      "                                                                 \n",
      " quant_conv_4 (Conv2DQuantiz  (None, 18, 18, 8)        1179      \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      " relu_4 (ReLU)               (None, 18, 18, 8)         0         \n",
      "                                                                 \n",
      " max_pool_0 (MaxPooling2D)   (None, 9, 9, 8)           0         \n",
      "                                                                 \n",
      " flatten_0 (Flatten)         (None, 648)               0         \n",
      "                                                                 \n",
      " quant_dense_0 (DenseQuantiz  (None, 100)              65103     \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      " relu_5 (ReLU)               (None, 100)               0         \n",
      "                                                                 \n",
      " quant_dense_1 (DenseQuantiz  (None, 10)               1033      \n",
      " eWrapper)                                                       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 164,791\n",
      "Trainable params: 164,058\n",
      "Non-trainable params: 733\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "quantized_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's check the quantized model's accuracy immediately after Q/DQ nodes are inserted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Quantization test accuracy immediately after QDQ insertion: 0.883899986743927\n"
     ]
    }
   ],
   "source": [
    "# Compile quantized model\n",
    "quantized_model.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(0.0001),\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "# Get accuracy immediately after QDQ nodes are inserted.\n",
    "_, q_aware_model_accuracy = quantized_model.evaluate(test_images, test_labels, verbose=0)\n",
    "print(\"Quantization test accuracy immediately after QDQ insertion:\", q_aware_model_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The model's accuracy decreases a bit as soon as Q/DQ nodes are inserted, requiring fine-tuning to recover it.\n",
    "\n",
    "```{note}\n",
    "\n",
    "Since this is a very small model, accuracy drop is small. For standard models like ResNets, accuracy drop immediately after QDQ insertion can be significant.\n",
    "\n",
    "```\n",
    "\n",
    "## 3. Fine-tune\n",
    "Since the quantized model behaves similar to the original keras model, the same training recipe can be used for fine-tuning as well.\n",
    "\n",
    "We fine-tune the model for two epochs and evaluate the model on the test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "1688/1688 [==============================] - 26s 15ms/step - loss: 0.1793 - accuracy: 0.9340 - val_loss: 0.2468 - val_accuracy: 0.9112\n",
      "Epoch 2/2\n",
      "1688/1688 [==============================] - 25s 15ms/step - loss: 0.1725 - accuracy: 0.9373 - val_loss: 0.2484 - val_accuracy: 0.9070\n",
      "Quantization test accuracy after fine-tuning: 0.9075999855995178\n",
      "Baseline test accuracy (for reference): 0.888700008392334\n"
     ]
    }
   ],
   "source": [
    "# fine tune quantized model for 2 epochs.\n",
    "quantized_model.fit(\n",
    "    train_images, train_labels, batch_size=32, epochs=2, validation_split=0.1\n",
    ")\n",
    "# Get quantized accuracy\n",
    "_, q_aware_model_accuracy_finetuned = quantized_model.evaluate(test_images, test_labels, verbose=0)\n",
    "print(\"Quantization test accuracy after fine-tuning:\", q_aware_model_accuracy_finetuned)\n",
    "print(\"Baseline test accuracy (for reference):\", baseline_model_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "```{note}\n",
    "\n",
    "If the network is not fully converged, the fine-tuned model's accuracy can surpass the original model's accuracy.\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as conv_0_layer_call_fn, conv_0_layer_call_and_return_conditional_losses, conv_1_layer_call_fn, conv_1_layer_call_and_return_conditional_losses, conv_2_layer_call_fn while saving (showing 5 of 14). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: GettingStarted/example/int8/saved_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: GettingStarted/example/int8/saved_model/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ONNX conversion Done!\n"
     ]
    }
   ],
   "source": [
    "# save TF INT8 original model\n",
    "tf.keras.models.save_model(quantized_model, assets.example.int8_saved_model)\n",
    "\n",
    "# Convert INT8 model to ONNX\n",
    "utils.convert_saved_model_to_onnx(saved_model_dir = assets.example.int8_saved_model, onnx_model_path = assets.example.int8_onnx_model)\n",
    "\n",
    "tf.keras.backend.clear_session()"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "In this example, accuracy loss due to quantization is recovered in just two epochs.\n",
    "\n",
    "This NVIDIA(R) Quantization Toolkit provides an easy interface to create quantized networks, and thus take advantage of INT8 inference on NVIDIA(R) GPUs using TensorRT(TM)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4442e1c252d743d7d1ab28567e302ebe8a15da81acb5d7e7894db75e10bdb29d"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}